/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_macro_v220_impl.h
 * \brief
 */
#ifndef IMPL_MATMUL_MATMUL_MACRO_V220_L0CACHE_IMPL_H
#define IMPL_MATMUL_MATMUL_MACRO_V220_L0CACHE_IMPL_H

#include "matmul_macro_v220_intf.h"

namespace AscendC {


// ===========mad template=================/
// Cmatrix type, Amatrix type, Bmatrix type, L0C_using_uniflag, L0C_using_hset
// L0CACHE specialization: keeps track of which L1 positions are already resident in the
// ping/pong halves of L0A/L0B so repeated Compute() calls can skip redundant L1->L0 copies.
// Template args: C/A/B/bias matrix types, L0C unit-flag enable, GEMV mode, L0-cache enable,
// and whether A2/B2 buffers are shared.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
class MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED, enable_if_t<L0CACHE>> {
public:
    inline __aicore__ MacroMatmul() {}
    inline __aicore__ ~MacroMatmul();
#ifdef ASCENDC_CPU_DEBUG
    // CPU debug: these base addresses are rebased onto host heap allocations in Init(),
    // so they must remain mutable members in this build.
    uint64_t L0A_PING = L0A_PING_D;
    uint64_t L0A_PONG = L0A_PONG_D;
    uint64_t L0B_PING = L0B_PING_D;
    uint64_t L0B_PONG = L0B_PONG_D;
    uint64_t BIAS_PING = BIAS_PING_D;
    uint64_t BIAS_PONG = BIAS_PONG_D;
#else
    constexpr static uint64_t L0A_PING = L0A_PING_D;
    constexpr static uint64_t L0A_PONG = L0A_PONG_D;
    constexpr static uint64_t L0B_PING = L0B_PING_D;
    constexpr static uint64_t L0B_PONG = L0B_PONG_D;
    constexpr static uint64_t BIAS_PING = BIAS_PING_D;
    constexpr static uint64_t BIAS_PONG = BIAS_PONG_D;
#endif
    // ---- tiling arguments ----
    uint16_t sAL1M_;         // valid M of the A tile resident in L1
    uint16_t sAL1K_;         // valid K of the A tile resident in L1
    uint16_t sAL1MOffset_;   // M offset of the current mad inside the L1 A tile
    uint16_t sAL1KOffset_;   // K offset of the current mad inside the L1 A tile
    uint16_t sBL1N_;         // valid N of the B tile resident in L1
    uint16_t sBL1K_;         // valid K of the B tile resident in L1
    uint16_t sBL1NOffset_;   // N offset of the current mad inside the L1 B tile
    uint16_t sBL1KOffset_;   // K offset of the current mad inside the L1 B tile
    uint16_t sL1BiasOffset_; // element offset of the bias segment in L1
    uint16_t sMadM_;         // M of one macro mad
    uint16_t sMadN_;         // N of one macro mad
    uint16_t sMadK_;         // total K covered by one Compute() call
    uint16_t sMad0K_;        // K of one inner mad iteration
    uint16_t sL0cInit_; // 0: normal (accumulate)  1: first mad initializes L0C
    uint16_t sL0cLast_; // 0: normal (more mads follow)  1: last mad of the accumulation
    uint64_t useL0PingPong_; // increment applied to the L0 ping-pong flag each iteration
    // feature map
    constexpr static uint16_t sFmH_ = 1;
    // ---- ping-pong state ----
    uint16_t ssl0PingPongFlag_;
    uint16_t ssAl0PingPongFlag_;
    uint16_t ssBl0PingPongFlag_;
    constexpr static uint16_t ssBiasFull_ = 0;
    uint16_t ssBiasPingPongFlag_;
    uint16_t kDirectionAlign_;
    // instance args
    // 0:format(M, K)
    // 1:format(K, M), need set transpose
    uint16_t ssAmatrixTranspose_;
    // 0:format(K, N), use load3dv2 carry
    // 1:format(N, K), use load2d carry
    uint16_t ssBmatrixTranspose_;
    // 0: no bias
    // 1: fp16
    // 2: fp32
    uint16_t biasType_;
    constexpr static uint16_t typeSize_ = sizeof(A_T);
    A_T aScalar_; // broadcast scalar for A when GEMV_MODE == 2
    A_T bScalar_;
    // tpipe buffers backing L0A, L0B and the bias table
    TBuf<TPosition::A2> l0aBuf_;
    TBuf<TPosition::B2> l0bBuf_;
    TBuf<TPosition::C2> biasBuf_;
#ifdef ASCENDC_CPU_DEBUG
    uint64_t pA;    // host allocation emulating L0A
    uint64_t pB;    // host allocation emulating L0B
    uint64_t pBias; // host allocation emulating the bias buffer
    bool initFlag = false;
#endif
    // L0 cache bookkeeping: the L1 positions currently held in the ping/pong halves
    // of L0A/L0B, plus round-robin load counters used to pick the victim half.
    int32_t cachePosAPing_ { 0 };
    int32_t cachePosAPong_ { 0 };
    int32_t cachePosBPing_ { 0 };
    int32_t cachePosBPong_ { 0 };
    int32_t cacheProcA_ { 0 };
    int32_t cacheProcB_ { 0 };

    inline __aicore__ void Init();
    inline __aicore__ void Release();
    // Invalidate the L0 cache state so the next Compute() reloads from L1.
    inline __aicore__ void ResetCache();
    // Execute one macro matmul step (L1 -> L0 carry + mad accumulation into L0C).
    template <bool noBias = false, bool noTail = false, bool intraBlockPartSum = false,
            ScheduleType scheduleType = ScheduleType::INNER_PRODUCT, IterateOrder iterateOrder = IterateOrder::UNDEF,
            bool isNormOuter = false>
    inline __aicore__ void Compute(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias,
        int64_t offsetb = 0, uint8_t subIdx = 0, uint16_t sMadMStep = 0, uint16_t sMadNStep = 0,
        uint32_t posA = 0, uint32_t posB = 0, uint16_t sBaseM = 0, uint16_t sBaseN = 0);
    // M-direction double-buffer variant; intentionally empty in this specialization.
    template <bool noBias = false>
    inline __aicore__ void ComputeWithMdb(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias, uint64_t kC0Tail, uint64_t kTail,
        uint16_t sMadMStep) {}
    // N-direction double-buffer variant; intentionally empty in this specialization.
    template <bool noBias = false>
    inline __aicore__ void ComputeWithNdb(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias, uint64_t kC0Tail, uint64_t kTail,
        uint16_t sMadNStep) {}
    inline __aicore__ void InitSetFlag();
    inline __aicore__ void LoadL12L0BFullLoad(const LocalTensor<B_T> &l1B, uint8_t subIdx,
        uint16_t sMad0K, uint16_t sMadN, uint16_t sBL1N, uint16_t sBL1NOffset, uint16_t sBL1KOffset,
        uint16_t offset);
    // Number of elements per fractal in the K direction, derived from the type combo.
    inline __aicore__ constexpr static uint16_t GetHwK0()
    {
        if constexpr (IsSameType<C_T, float>::value && sizeof(A_T) == sizeof(half)) {
            return 16;
        } else if constexpr (IsSameType<C_T, float>::value && IsSameType<A_T, float>::value) {
            return 8;
        } else if constexpr (IsSameType<A_T, int8_t>::value) {
            return 32;
        } else if constexpr (IsSameType<A_T, int4b_t>::value) {
            return 64;
        } else {
            // Fallback mirrors GetMode()'s F162F32 default; without it the function
            // could flow off the end of a value-returning function (UB).
            return 16;
        }
    }
    inline __aicore__ constexpr static madtype GetMode() {
        if constexpr (IsSameType<C_T, float>::value && sizeof(A_T) == sizeof(half)) {
            return F162F32;
        } else if constexpr (IsSameType<C_T, float>::value && IsSameType<A_T, float>::value) {
            return F322F32;
        } else if constexpr (IsSameType<A_T, int8_t>::value) {
            return S82S32;
        } else if constexpr (IsSameType<A_T, int4b_t>::value) {
            return S42S32;
        } else {
            return F162F32;
        }
    }
    constexpr static madtype mode_ = GetMode();
private:
    inline __aicore__ void LoadL12L0A(uint64_t k_inner, uint64_t aPoskPtr, uint16_t usedK,
        const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A);
    inline __aicore__ void LoadL12L0B(uint64_t k_inner, uint16_t usedK,
        const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B);
    inline __aicore__ void LoadL12L0ACache(uint32_t posA, uint64_t k_inner, uint64_t aPoskPtr, uint16_t usedK,
        const LocalTensor<A_T> &l1AMatrix, LocalTensor<A_T> &l0a);
    template <bool intraBlockPartSum = false>
    inline __aicore__ void LoadL12L0BCache(uint32_t posB, uint64_t k_inner, int64_t offsetb,
        uint16_t usedK, const LocalTensor<B_T> &l1BMatrix, LocalTensor<B_T> &l0b);
    inline __aicore__ void MmadMacro(const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B,
        const LocalTensor<C_T> &cMatrix, uint16_t mmadK, uint8_t unitFlag, bool l0c_initial);
};

template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::~MacroMatmul()
{
#ifdef ASCENDC_CPU_DEBUG
    // CPU debug builds emulate L0A/L0B/bias with host heap buffers allocated in Init();
    // release them here. initFlag guards against Init() never having been called.
    if (initFlag) {
        free((__ca__ A_T *)pA);
        free((__cb__ B_T *)pB);
        free((C_T *)pBias);
    }
#endif
}

// Issue one mad instruction: cMatrix += l0A x l0B (or initialize/seed L0C on the
// first mad of an accumulation group). Parameters come from the current tiling state.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::MmadMacro(
    const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B, const LocalTensor<C_T> &cMatrix,
    uint16_t mmadK, uint8_t unitFlag, bool l0c_initial)
{
    // Effective M for the mad: GEMV modes always compute a single row, while
    // matrix mode widens M == 1 up to a full 16-row fractal.
    uint16_t effM;
    if constexpr (GEMV_MODE >= 1) {
        effM = 1;
    } else {
        effM = (sMadM_ == 1) ? 16 : sMadM_;
    }

    MmadParams mmadParams;
    mmadParams.m = effM;
    mmadParams.k = mmadK;
    mmadParams.n = sMadN_;
    mmadParams.unitFlag = unitFlag;
    mmadParams.kDirectionAlign = kDirectionAlign_;
    if (biasType_) {
        // With bias, the first mad of a group seeds L0C from the bias table
        // (cmatrixSource); subsequent mads accumulate.
        mmadParams.cmatrixSource = l0c_initial;
        mmadParams.cmatrixInitVal = false;
    } else {
        // Without bias, the first mad of a group zero-initializes L0C instead.
        mmadParams.cmatrixSource = false;
        mmadParams.cmatrixInitVal = l0c_initial;
    }
    Mmad(cMatrix, l0A, l0B, mmadParams);

    // NOTE(review): small mads (fewer than 10 fractal tiles) get an explicit M-pipe
    // barrier — presumably to serialize back-to-back mads; confirm the threshold.
    if ((effM / ALIGN_NUM) * (sMadN_ / ALIGN_NUM) < 10) {
        PipeBarrier<PIPE_M>();
    }
}

// Carry one k-slice of matrix A from L1 into L0A.
//   k_inner:  index of the current sMad0K_-sized slice along K
//   aPoskPtr: element offset of that slice inside the L1 A buffer
//   usedK:    number of valid K elements to carry this round
// The carry path depends on GEMV_MODE and on whether A is stored transposed.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::LoadL12L0A(
    uint64_t k_inner, uint64_t aPoskPtr, uint16_t usedK,
    const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A)
{
    if constexpr (GEMV_MODE == 2) {
        // Scalar-A GEMV: broadcast aScalar_ directly into L0A; nothing is read from L1.
        ASSERT(sMadM_ == 1);
        InitConstValueParams initConstValueParams {1, (uint16_t)ConstCeil(sMadK_, BLOCK_CUBE * GetHwK0()), 0, aScalar_};
        InitConstValue(l0A, initConstValueParams);
        return;
    }
    if constexpr (GEMV_MODE == 1) {
        // Vector-A GEMV: contiguous load2d copy of usedK elements, one fractal per repeat.
        int FracSize = BYTE_PER_FRACTAL / sizeof(B_T);
        int repeat = CeilDiv(usedK, FracSize);
        // aPoskPtr is unit of element
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = repeat;
        loadDataParams.srcStride = 1;
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = 0;
        LoadData(l0A[0], l1A[aPoskPtr], loadDataParams);
        return;
    }
    if (ssAmatrixTranspose_ > 0) {
        // K_axis is m direction, and M_axis is k direction in load3d intrin
        if constexpr (IsSameType<A_T, int8_t>::value) {
            // int8 transposed A: use load2d-with-transpose, looping over M fractals.
            uint16_t sMad0MAlign = CeilAlign(sMadM_, GetHwK0());
            uint16_t l0aloop = sMad0MAlign / GetHwK0();
            uint8_t l0aRepeat = CeilDiv(usedK, GetHwK0());
            uint64_t l0aSrcAddrStride = sAL1K_ * GetHwK0() ;
            uint64_t l0aDstAddrStride = CeilDiv(usedK, GetHwK0()) * GetHwK0() * GetHwK0();

            // Source offsets are in bytes; the two chip generations index the L1
            // slice differently (M offset vs K offset based).
#if __CCE_AICORE__ >= 300
            uint64_t l1aOffset = CeilDiv(sAL1MOffset_, GetHwK0()) * GetHwK0() * GetHwK0() * typeSize_ +
                k_inner * l0aRepeat * GetHwK0() * GetHwK0() * typeSize_;
#else
            uint8_t l0aRepeatOffset = CeilDiv(sMad0K_, GetHwK0());
            uint64_t l1aOffset = CeilDiv(sAL1KOffset_, GetHwK0()) * GetHwK0() * GetHwK0() * typeSize_ +
                k_inner * l0aRepeatOffset * GetHwK0() * GetHwK0() * typeSize_;
#endif
            uint64_t l0aOffset = 0;
            LoadData2dTransposeParams loadData2dTransposeParams;
            loadData2dTransposeParams.startIndex = 0;
            loadData2dTransposeParams.repeatTimes = l0aRepeat;
            loadData2dTransposeParams.srcStride = 1;
            loadData2dTransposeParams.dstGap = 0;
            loadData2dTransposeParams.dstFracGap = (uint16_t)(l0aRepeat - 1);
            loadData2dTransposeParams.addrMode = inc;
            for (uint16_t i = 0; i < l0aloop; ++i) {
                LoadDataWithTranspose(l0A[l0aOffset], l1A[l1aOffset], loadData2dTransposeParams);
                l1aOffset += l0aSrcAddrStride;
                l0aOffset += l0aDstAddrStride;
            }
        } else {
            // format(K, M), K, M need to be 16 aligned for f32
            uint16_t madMAlign = CeilAlign(sMadM_, ALIGN_NUM);
            uint16_t usedKAlign = CeilAlign(usedK, HW_M0);
            uint16_t sAL1MAlign = CeilAlign(sAL1M_, ALIGN_NUM);
            LoadData3DParamsV2Pro loadData3DV2;
            loadData3DV2.channelSize = sAL1MAlign;
            // extConfig packs {kStartPt, mStartPt, kExtension, mExtension} into one u64.
            loadData3DV2.extConfig = ((uint64_t)aPoskPtr << 48) | ((uint64_t)sAL1MOffset_ << 32) |
                                   ((uint64_t)usedKAlign << 16) | (uint64_t)madMAlign;
            loadData3DV2.enTranspose = true;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
            // bf16 uses the non-template overload on these cores.
            if constexpr (IsSameType<A_T, bfloat16_t>::value) {
                LoadData(l0A[0], l1A[0], loadData3DV2);
            } else {
                LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
            }
#else
            LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
#endif
        }
    } else {
        // format(M, K), K_axis is k direction, and M_axis is m direction in load3d intrin
        uint16_t madMAlign = CeilAlign(sMadM_, HW_M0);
        uint16_t usedKAlign = CeilAlign(usedK, GetHwK0());
        uint16_t sAL1KAlign = CeilAlign(sAL1K_, GetHwK0());
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = sAL1KAlign;
        // extConfig packs {mStartPt, kStartPt, mExtension, kExtension} into one u64.
        loadData3DV2.extConfig = ((uint64_t)sAL1MOffset_ << 48) | ((uint64_t)aPoskPtr << 32) |
                               ((uint64_t)madMAlign << 16) | (uint64_t)usedKAlign;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
        if constexpr (IsSameType<A_T, bfloat16_t>::value) {
            LoadData(l0A[0], l1A[0], loadData3DV2);
        } else {
            LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
        }
#else
        LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
#endif
    }
}

// Pre-set both M->MTE1 ping/pong events so the first WaitFlag in Compute() does not
// block before any mad has been issued; Release() consumes them at teardown.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::InitSetFlag()
{
    SetFlag<HardEvent::M_MTE1>(EVENT_ID0);
    SetFlag<HardEvent::M_MTE1>(EVENT_ID1);
}

// Carry a full B tile from L1 into L0B at an explicit destination offset, optionally
// into the upper half of L0B when subIdx != 0 (two sub-blocks sharing L0B).
// Chooses load2d (B stored as (N, K)) or load3dv2 (B stored as (K, N)).
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::LoadL12L0BFullLoad(
    const LocalTensor<B_T> &l1B, uint8_t subIdx, uint16_t sMad0K, uint16_t sMadN, uint16_t sBL1N,
    uint16_t sBL1NOffset, uint16_t sBL1KOffset, uint16_t offset)
{
    auto l0b = l0bBuf_.Get<B_T>();
    if ((subIdx) != 0) {
        // Second sub-block writes into the upper half of L0B.
        l0b = l0b[L0BUF_SIZE / 2 / sizeof(B_T)];
    }
    if (ssBmatrixTranspose_ > 0) {
        // SET LOAD2D parameters , loop axis: K or M, or 1
        // k is GetHwK0() aligned for f32
        uint16_t sMad0KAlign = CeilAlign(sMad0K, GetHwK0());
        uint16_t kC0 = sMad0KAlign / GetHwK0();
        uint16_t nFraC0 = CeilDiv(sMadN, HW_N0);
        uint16_t l0bLoop = 1;
        uint64_t l0bSrcAddrStride = 0;
        uint64_t l0bDstAddrStride = 0;
        uint8_t l0bRepeat = kC0 * nFraC0;
        uint16_t l0bSrcstride = 1;
        uint16_t l0bDststride = 0;

        // Pick the loop/repeat split that yields the fewest LoadData issues.
        if (nFraC0 * HW_N0 == sBL1N) {
            l0bLoop = 1;            // loop=1
        } else if (nFraC0 >= kC0) { // LOOP is K  and repeat is n axis
            l0bLoop = kC0;
            l0bSrcAddrStride = sBL1N * GetHwK0() * typeSize_;
            l0bDstAddrStride = nFraC0 * HW_N0 * GetHwK0() * typeSize_;
            l0bRepeat = nFraC0;

            l0bSrcstride = 1;
            l0bDststride = 0;
        } else { // LOOP is N  and repeat is K axis
            l0bLoop = nFraC0;
            l0bSrcAddrStride = HW_N0 * GetHwK0() * typeSize_;
            l0bDstAddrStride = HW_N0 * GetHwK0() * typeSize_;
            l0bRepeat = kC0;

            l0bSrcstride = (sBL1N + HW_N0 - 1) / HW_N0;
            l0bDststride = nFraC0 - 1;
        }
        // use load2d for L1_2_L0B
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = l0bRepeat;
        loadDataParams.srcStride = l0bSrcstride;
        loadDataParams.dstGap = l0bDststride;
        loadDataParams.ifTranspose = 0;
        uint64_t l1bOffset = sBL1NOffset * GetHwK0() + sBL1KOffset * sBL1N;
        uint64_t l0bOffset = offset;
        for (uint64_t i = 0; i < l0bLoop; i++) {
            LoadData(l0b[l0bOffset], l1B[l1bOffset], loadDataParams);
            // Strides are in bytes; tensor indexing is in elements.
            l1bOffset += (l0bSrcAddrStride / typeSize_);
            l0bOffset += (l0bDstAddrStride / typeSize_);
        }
    } else {
        // use load3dv2 for L1_2_L0B
        // n_axis is K direction, need to be 16 aligned
        uint16_t kAlign = CeilAlign(sMadN, ALIGN_NUM);
        uint16_t mPos = sBL1KOffset;
        // channel size need to be 16 aligned
        uint16_t cAlign = CeilAlign(sBL1N, ALIGN_NUM);
        // k_axis is M direction, need to be HW_M0 aligned
        uint16_t mAlign = CeilAlign(sMad0K, HW_M0);
        // StepN need to be aligned

            LoadData3DParamsV2Pro loadData3DV2;
            loadData3DV2.channelSize = cAlign;
            // extConfig packs {mStartPt, kStartPt, mExtension, kExtension} into one u64.
            loadData3DV2.extConfig = ((uint64_t)mPos << 48) | ((uint64_t)sBL1NOffset << 32) |
                                   ((uint64_t)mAlign << 16) | (uint64_t)kAlign;
            loadData3DV2.fMatrixCtrl = true;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
            // bf16 uses the non-template overload on these cores.
            if constexpr (IsSameType<B_T, bfloat16_t>::value) {
                LoadData(l0b[offset], l1B[0], loadData3DV2);
            } else {
                LoadData<B_T>(l0b[offset], l1B[0], loadData3DV2);
            }
#else
            LoadData<B_T>(l0b[offset], l1B[0], loadData3DV2);
#endif
    }
}

// Carry one k-slice of matrix B from L1 into L0B.
//   k_inner: index of the current sMad0K_-sized slice along K
//   usedK:   number of valid K elements to carry this round
// Path selection: load2d for B stored as (N, K); load2d-with-transpose for int8/int4
// B stored as (K, N); load3dv2 otherwise.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::LoadL12L0B(
    uint64_t k_inner, uint16_t usedK, const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B)
{
    uint16_t sMad0KAlign = CeilAlign(sMad0K_, GetHwK0());
    uint16_t kC0 = sMad0KAlign / GetHwK0();
    if (ssBmatrixTranspose_ > 0) {
        // SET LOAD2D parameters , loop axis: K or M, or 1
        // k is GetHwK0() aligned for f32
        uint16_t nFraC0 = CeilDiv(sMadN_, HW_N0);
        uint16_t l0bLoop = 1;
        uint64_t l0bSrcAddrStride = 0;
        uint64_t l0bDstAddrStride = 0;
        uint8_t l0bRepeat = kC0 * nFraC0;
        uint16_t l0bSrcstride = 1;
        uint16_t l0bDststride = 0;

        // Pick the loop/repeat split that yields the fewest LoadData issues.
        if (nFraC0 * HW_N0 == sBL1N_) {
            l0bLoop = 1;            // loop=1
        } else if (nFraC0 >= kC0) { // LOOP is K  and repeat is n axis
            l0bLoop = kC0;
            l0bSrcAddrStride = sBL1N_ * GetHwK0() * typeSize_;
            l0bDstAddrStride = nFraC0 * HW_N0 * GetHwK0() * typeSize_;
            l0bRepeat = nFraC0;

            l0bSrcstride = 1;
            l0bDststride = 0;
        } else { // LOOP is N  and repeat is K axis
            l0bLoop = nFraC0;
            l0bSrcAddrStride = HW_N0 * GetHwK0() * typeSize_;
            l0bDstAddrStride = HW_N0 * GetHwK0() * typeSize_;
            l0bRepeat = kC0;

            l0bSrcstride = (sBL1N_ + HW_N0 - 1) / HW_N0;
            l0bDststride = nFraC0 - 1;
        }
        // use load2d for L1_2_L0B
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = l0bRepeat;
        loadDataParams.srcStride = l0bSrcstride;
        loadDataParams.dstGap = l0bDststride;
        loadDataParams.ifTranspose = 0;
        uint64_t l1bOffset = sBL1NOffset_ * GetHwK0() + sBL1KOffset_ * sBL1N_ +
            k_inner * kC0 * GetHwK0() * sBL1N_;
        uint64_t l0bOffset = 0;
        for (uint64_t i = 0; i < l0bLoop; i++) {
            LoadData(l0B[l0bOffset], l1B[l1bOffset], loadDataParams);
            // Strides are in bytes; tensor indexing is in elements.
            l1bOffset += (l0bSrcAddrStride / typeSize_);
            l0bOffset += (l0bDstAddrStride / typeSize_);
        }
    } else {
        if constexpr (IsSameType<B_T, int8_t>::value || IsSameType<B_T, int4b_t>::value) {
            // use load2d transpose for L1_2_L0B
            // NOTE(review): this sMad0KAlign shadows the function-scope one above and
            // aligns usedK (not sMad0K_) — apparently intentional for the tail slice.
            uint16_t sMad0KAlign = CeilAlign(usedK, GetHwK0());
            uint16_t l0bloop = sMad0KAlign / GetHwK0();
            uint16_t l0bSrcstride = CeilDiv(sBL1K_, GetHwK0());
            uint16_t l0bRepeat = CeilDiv(sMadN_, GetHwK0());
            uint64_t l0bSrcAddrStride = GetHwK0() * GetHwK0();
            uint64_t l0bDstAddrStride = CeilDiv(sMadN_, 16) * 16 * GetHwK0();
            uint64_t l1bOffset = sBL1NOffset_ * sBL1K_ * typeSize_ + sBL1KOffset_ * GetHwK0() * typeSize_ +
                k_inner * kC0 * GetHwK0() * GetHwK0() * typeSize_;
            uint64_t l0bOffset = 0;

            LoadData2dTransposeParams loadData2dTransposeParams;
            loadData2dTransposeParams.startIndex = 0;
            loadData2dTransposeParams.repeatTimes = l0bRepeat;
            loadData2dTransposeParams.srcStride = l0bSrcstride;
            loadData2dTransposeParams.dstGap = 1;
            if constexpr (IsSameType<B_T, int4b_t>::value) {
                // int4 packs two elements per byte, so the destination gap differs.
                loadData2dTransposeParams.dstGap = CeilDiv(GetHwK0(), 16) - 1;
            }
            loadData2dTransposeParams.dstFracGap = 0;
            loadData2dTransposeParams.addrMode = inc;

            for (uint64_t i = 0; i < l0bloop; i++) {
                LoadDataWithTranspose(l0B[l0bOffset], l1B[l1bOffset], loadData2dTransposeParams);
                l1bOffset += l0bSrcAddrStride;
                l0bOffset += l0bDstAddrStride;
            }
        } else {
            // use load3dv2 for L1_2_L0B
            // n_axis is K direction, need to be 16 aligned
            uint16_t kAlign = CeilAlign(sMadN_, ALIGN_NUM);
            uint16_t mPos = sBL1KOffset_ + k_inner * sMad0K_;
            // channel size need to be 16 aligned
            uint16_t cAlign = CeilAlign(sBL1N_, ALIGN_NUM);
            // k_axis is M direction, need to be HW_M0 aligned
            uint16_t mAlign = CeilAlign(usedK, HW_M0);
            // StepN need to be aligned
            LoadData3DParamsV2Pro loadData3DV2;
            loadData3DV2.channelSize = cAlign;
            // extConfig packs {mStartPt, kStartPt, mExtension, kExtension} into one u64.
            loadData3DV2.extConfig = ((uint64_t)mPos << 48) | ((uint64_t)sBL1NOffset_ << 32) |
                                   ((uint64_t)mAlign << 16) | (uint64_t)kAlign;
            loadData3DV2.fMatrixCtrl = true;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
            // bf16 uses the non-template overload on these cores.
            if constexpr (IsSameType<B_T, bfloat16_t>::value) {
                LoadData(l0B[0], l1B[0], loadData3DV2);
            } else {
                LoadData<B_T>(l0B[0], l1B[0], loadData3DV2);
            }
#else
            LoadData<B_T>(l0B[0], l1B[0], loadData3DV2);
#endif
        }
    }
}

// initialization: allocate the L0A/L0B/bias tpipe buffers and reset all
// ping-pong, transpose and bias state to their defaults.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::Init()
{
    if constexpr (unlikely(UNIFLAG_EN)) {
        SetMMLayoutTransform(0);
    }
#ifdef ASCENDC_CPU_DEBUG
    // allocate 64K L0A space for cpu debug
    pA = (uint64_t)((__ca__ A_T *)malloc(L0AUF_SIZE));
    // allocate 64K L0B space for cpu debug
    pB = (uint64_t)((__cb__ B_T *)malloc(L0BUF_SIZE));
    pBias = (uint64_t)((C_T *)malloc(BIAS_BUF_SIZE));
    initFlag = true;
    // Rebase the ping/pong addresses onto the host allocations (freed in the dtor).
    L0A_PING += pA;
    L0A_PONG += pA;
    L0B_PING += pB;
    L0B_PONG += pB;
    BIAS_PING += pBias;
    BIAS_PONG += pBias;
#endif
    ssl0PingPongFlag_ = 0;
    ssAl0PingPongFlag_ = 0;
    ssBl0PingPongFlag_ = 0;
    // With D closed, the bias address needs to be taken from Xd[63:32].
    ssBiasPingPongFlag_ = 0;
    ssAmatrixTranspose_ = 0;
    ssBmatrixTranspose_ = 0;
    biasType_ = 0;

    kDirectionAlign_ = 0;


    // First Compute() initializes L0C; nothing is the "last" mad yet.
    sL0cInit_ = 1;
    sL0cLast_ = 0;

    GetTPipePtr()->InitBuffer(l0aBuf_, L0AUF_SIZE);
    GetTPipePtr()->InitBuffer(l0bBuf_, L0BUF_SIZE);
    GetTPipePtr()->InitBuffer(biasBuf_, BIAS_BUF_SIZE);
}

// Consume the two M->MTE1 ping/pong events set by InitSetFlag()/Compute() so no
// hardware event is left outstanding when the macro is torn down.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::Release()
{
    WaitFlag<HardEvent::M_MTE1>(EVENT_ID0);
    WaitFlag<HardEvent::M_MTE1>(EVENT_ID1);
}

// Drop any cached L0A/L0B positions and restart the round-robin load counters,
// forcing the next LoadL12L0ACache/LoadL12L0BCache call to reload from L1.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::ResetCache()
{
    cacheProcA_ = 0;
    cacheProcB_ = 0;
    cachePosAPing_ = 0;
    cachePosAPong_ = 0;
    cachePosBPing_ = 0;
    cachePosBPong_ = 0;
}

// Cached wrapper around LoadL12L0A: if the requested L1 position posA is already
// resident in the ping or pong half of L0A, only switch the ping-pong flag; otherwise
// load into the round-robin victim half and record the new position.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::LoadL12L0ACache(
    uint32_t posA, uint64_t k_inner, uint64_t aPoskPtr, uint16_t usedK,
    const LocalTensor<A_T> &l1AMatrix, LocalTensor<A_T> &l0a)
{
    bool hitCachePing = posA == cachePosAPing_ ? true : false;
    bool hitCachePong = posA == cachePosAPong_? true : false;

    // update cache ping or pong
    if (hitCachePing) {
        ssAl0PingPongFlag_ = 0;
    } else if (hitCachePong) {
        ssAl0PingPongFlag_ = 1;
    } else {
        // Miss: alternate between ping and pong based on how many loads happened.
        ssAl0PingPongFlag_ = cacheProcA_ % Impl::DB_FACTOR == 0 ? 0 : 1;
    }

    if (ssAl0PingPongFlag_) {
        // Advance l0a to the pong half (int4 indexing counts packed elements).
        if constexpr (IsSameType<A_T, int4b_t>::value) {
            l0a = l0a[L0AUF_SIZE / sizeof(A_T)];
        } else {
            l0a = l0a[L0AUF_SIZE / Impl::DB_FACTOR / sizeof(A_T)];
        }
    }

    // Very first load: posA == 0 aliases the freshly reset cache slots, so it would
    // be treated as a hit above; force the actual load here.
    if (cacheProcA_ == 0 && posA == 0) {
        ++cacheProcA_;
        return LoadL12L0A(k_inner, aPoskPtr, usedK, l1AMatrix, l0a);
    }
    if (!hitCachePing && !hitCachePong) {
        LoadL12L0A(k_inner, aPoskPtr, usedK, l1AMatrix, l0a);
        ++cacheProcA_;
        if (ssAl0PingPongFlag_ == 0) {
            cachePosAPing_ = posA;
        } else {
            cachePosAPong_ = posA;
        }
    }
}

// Cached wrapper around LoadL12L0B, mirroring LoadL12L0ACache for the B matrix.
// With intraBlockPartSum, no new load is issued on a miss; l0b is instead advanced
// by the caller-supplied offsetb (the data was placed there by another path).
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
    uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
template <bool intraBlockPartSum>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::LoadL12L0BCache(
    uint32_t posB, uint64_t k_inner, int64_t offsetb, uint16_t usedK,
    const LocalTensor<B_T> &l1BMatrix, LocalTensor<B_T> &l0b)
{
    bool hitCachePing = posB == cachePosBPing_ ? true : false;
    bool hitCachePong = posB == cachePosBPong_ ? true : false;

    // update cache ping or pong
    if (hitCachePing) {
        ssBl0PingPongFlag_ = 0;
    } else if (hitCachePong) {
        ssBl0PingPongFlag_ = 1;
    } else {
        // Miss: alternate between ping and pong based on how many loads happened.
        ssBl0PingPongFlag_ = cacheProcB_ % Impl::DB_FACTOR == 0 ? 0 : 1;
    }

    if (ssBl0PingPongFlag_) {
        // Advance l0b to the pong half (int4 indexing counts packed elements);
        // intraBlockPartSum manages its own offsets, so skip the adjustment.
        if constexpr (IsSameType<B_T, int4b_t>::value) {
            if constexpr (!intraBlockPartSum) {
                l0b = l0b[L0BUF_SIZE / sizeof(B_T)];
            }
        } else {
            if constexpr (!intraBlockPartSum) {
                l0b = l0b[L0BUF_SIZE / Impl::DB_FACTOR / sizeof(B_T)];
            }
        }
    }

    // Very first load: posB == 0 aliases the freshly reset cache slots, so it would
    // be treated as a hit above; force the actual load here.
    if (cacheProcB_ == 0 && posB == 0) {
        ++cacheProcB_;
        return LoadL12L0B(k_inner, usedK, l1BMatrix, l0b);
    }
    if (!hitCachePing && !hitCachePong) {
        if constexpr (!intraBlockPartSum) {
            LoadL12L0B(k_inner, usedK, l1BMatrix, l0b);
        } else {
            l0b = l0b[offsetb];
        }
        ++cacheProcB_;
        if (ssBl0PingPongFlag_ == 0) {
            cachePosBPing_ = posB;
        } else {
            cachePosBPong_ = posB;
        }
    }
}

// One macro matmul step: tile K into kLoop slices of sMad0K_ (plus an optional tail),
// and for each slice carry A/B from L1 into L0 (through the position cache), issue the
// mad, and maintain the M<->MTE1 ping-pong events. Bias is copied into C2 before the
// first mad when enabled. OUTER_PRODUCT schedules delegate to the M-db / N-db variants.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN,
     uint16_t GEMV_MODE, bool L0CACHE, bool ISA2B2SHARED>
template <bool noBias, bool noTail, bool intraBlockPartSum, ScheduleType scheduleType, IterateOrder iterateOrder,
    bool isNormOuter>
inline __aicore__ void MacroMatmul<IMPL, C_T, A_T, B_T, BIAS_T, UNIFLAG_EN, GEMV_MODE, L0CACHE, ISA2B2SHARED,
    enable_if_t<L0CACHE>>::Compute(
    const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix, const LocalTensor<C_T> &cMatrix,
    const LocalTensor<BIAS_T> &bias, int64_t offsetb, uint8_t subIdx, uint16_t sMadMStep, uint16_t sMadNStep,
    uint32_t posA, uint32_t posB, uint16_t sBaseM, uint16_t sBaseN)
{
    uint16_t madKC0 = CeilDiv(sMadK_, GetHwK0());
    // NOTE(review): nFraC0 appears unused in this function.
    uint32_t nFraC0 = CeilDiv(sMadN_, HW_N0);
    uint64_t kC0 = sMad0K_ / GetHwK0();
    uint64_t kLoop;
    if constexpr (noTail) {
        kLoop = 1;
    } else {
        kLoop = sMadK_ / sMad0K_;       // loop times of sMad0K_
    }
    uint64_t kC0Tail = madKC0 - kLoop * kC0; // tail block loop times, unit is 16
    uint64_t kTail;
    if constexpr (noTail) {
        kTail = 0;
    } else {
        kTail = sMadK_ - kLoop * sMad0K_;
    }

    // m db
    if constexpr (scheduleType == ScheduleType::OUTER_PRODUCT && iterateOrder == IterateOrder::ORDER_N && !isNormOuter) {
        ComputeWithMdb<noBias>(l1AMatrix, l1BMatrix, cMatrix, bias, kC0Tail, kTail, sMadMStep);
        return;
    }
    // n db
    if constexpr (scheduleType == ScheduleType::OUTER_PRODUCT && iterateOrder == IterateOrder::ORDER_M && !isNormOuter) {
        ComputeWithNdb<noBias>(l1AMatrix, l1BMatrix, cMatrix, bias, kC0Tail, kTail, sMadNStep);
        return;
    }

    // Program the load3d feature-matrix registers for the A (and non-transposed B) carries.
    if (ssAmatrixTranspose_ > 0) {
        if (mode_ == F322F32) {
            kDirectionAlign_ = 1;
        }
        uint16_t wAlign = CeilAlign(sAL1K_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, Impl::padList);
    } else {
        // fmatrix w should be 16 aligned
        uint16_t wAlign = CeilAlign(sAL1M_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, Impl::padList);
    }

    if (ssBmatrixTranspose_ < 1) {
        uint16_t wAlign = CeilAlign(sBL1K_, HW_M0);
        Load3DSetFMatrixBCal(sFmH_, wAlign, Impl::padList);
    }

    // Copy the bias segment into the C2 buffer before the initializing mad.
    if constexpr (!noBias) {
        if ((biasType_) && (sL0cInit_) && (ssBiasFull_ == 0)) {
            WaitFlag<HardEvent::M_MTE1>(2);
            // biasType_ doubles as an element-size factor (1: fp16, 2: fp32);
            // burst length is in 64-byte units.
            uint16_t lenBurst = (sMadN_ * biasType_ * 2 + 63) / 64;
            LocalTensor<C_T> biasC2;
            biasC2 = biasBuf_.Get<C_T>();
            if constexpr (!isNormOuter) {
                DataCopy(biasC2, bias[sL1BiasOffset_ * biasType_ * 2], {1, lenBurst, 0, 0});
            } else {
                uint32_t biasOffset = 0;
                if constexpr (scheduleType == ScheduleType::OUTER_PRODUCT && iterateOrder == IterateOrder::ORDER_M) {
                    // Pong iteration consumes the second N-block of the bias.
                    if ((ssl0PingPongFlag_  & 0x1) != 0) {
                        biasOffset += sBaseN;
                    }
                }
                DataCopy(biasC2, bias[sL1BiasOffset_ * biasType_ * 2 + biasOffset], {1, lenBurst, 0, 0});
            }
            SetFlag<HardEvent::MTE1_M>(2);
            WaitFlag<HardEvent::MTE1_M>(2);
        }
    }

    // Main loop over full sMad0K_ slices.
    LocalTensor<A_T> l0a;
    LocalTensor<B_T> l0b;
    for (uint64_t k_inner = 0; k_inner < kLoop; k_inner++) {
        l0a = l0aBuf_.Get<A_T>();
        l0b = l0bBuf_.Get<B_T>();
        if constexpr(intraBlockPartSum) {
            // Second sub-block reads its B data from the upper half of L0B.
            if ((subIdx) != 0) {
                l0b = l0b[L0BUF_SIZE / 2 / sizeof(B_T)];
            }
        }
        // Wait until the previous mad on this ping-pong half has finished reading L0.
        WaitFlag<HardEvent::M_MTE1>(ssl0PingPongFlag_ & 0x1);
        // load L0A
        uint64_t aPoskPtr = k_inner * kC0 * GetHwK0() + sAL1KOffset_;
        LoadL12L0ACache(posA, k_inner, aPoskPtr, sMad0K_, l1AMatrix, l0a);
        ++posA;
        // load L0B
        LoadL12L0BCache<intraBlockPartSum>(posB, k_inner, offsetb, sMad0K_, l1BMatrix, l0b);
        ++posB;
        SetFlag<HardEvent::MTE1_M>(ssl0PingPongFlag_ & 0x1);
        WaitFlag<HardEvent::MTE1_M>(ssl0PingPongFlag_ & 0x1);
        // MAD
        bool l0c_initial = (k_inner == 0) && (sL0cInit_);
        // unitFlag: 3 marks the final mad of the whole accumulation, 2 an intermediate one.
        uint8_t unitFlag = 0;
        if constexpr (UNIFLAG_EN) {
            if constexpr (intraBlockPartSum) {
                if (subIdx == 1) {
                    unitFlag = ((k_inner == (kLoop - 1)) && (sL0cLast_) && (kTail == 0)) ? 3 : 2;
                }
            } else {
                unitFlag = ((k_inner == (kLoop - 1)) && (sL0cLast_) && (kTail == 0)) ? 3 : 2;
            }
        }
        if constexpr (!isNormOuter) {
            MmadMacro(l0a, l0b, cMatrix, sMad0K_, unitFlag, l0c_initial);
        } else {
            // Pong iteration writes into the second L0C block.
            uint32_t l0cOffset = 0;
            if ((ssl0PingPongFlag_  & 0x1) != 0) {
                l0cOffset = sBaseM * sBaseN;
            }
            MmadMacro(l0a, l0b, cMatrix[l0cOffset], sMad0K_, unitFlag, l0c_initial);
        }
        // Signal MTE1 that this L0 half can be overwritten again.
        SetFlag<HardEvent::M_MTE1>(ssl0PingPongFlag_ & 0x1);
        if constexpr (!noBias) {
            if ((biasType_) && (l0c_initial) && (ssBiasFull_ == 0)) {
                SetFlag<HardEvent::M_MTE1>(2);
            }
        }
        // update pingpong flag
        ssl0PingPongFlag_ += useL0PingPong_;
    }
    // k  tail
    if constexpr (!noTail) {
        if (kTail != 0) {

            l0a = l0aBuf_.Get<A_T>();
            l0b = l0bBuf_.Get<B_T>();
            WaitFlag<HardEvent::M_MTE1>(ssl0PingPongFlag_ & 0x1);
            uint16_t tailK = kC0Tail * GetHwK0();
            uint64_t aPoskPtr = kLoop * kC0 * GetHwK0() + sAL1KOffset_;
            // load L0A
            LoadL12L0ACache(posA, kLoop, aPoskPtr, tailK, l1AMatrix, l0a);
            // load L0B
            LoadL12L0BCache<intraBlockPartSum>(posB, kLoop, offsetb, tailK, l1BMatrix, l0b);
            SetFlag<HardEvent::MTE1_M>(EVENT_ID0);
            WaitFlag<HardEvent::MTE1_M>(EVENT_ID0);
            // MAD
            // Tail only initializes L0C when there was no main loop at all (kLoop == 0).
            bool l0c_initial = (kLoop == 0) && (sL0cInit_);
            uint8_t unitFlag = 0;
            if constexpr (UNIFLAG_EN) {
                if constexpr (intraBlockPartSum) {
                    if (subIdx == 1) {
                        unitFlag = ((sL0cLast_)) ? 3 : 2;
                    }
                } else {
                    unitFlag = sL0cLast_ ? 3 : 2;
                }
            }
            MmadMacro(l0a, l0b, cMatrix, kTail, unitFlag, l0c_initial);
            SetFlag<HardEvent::M_MTE1>(ssl0PingPongFlag_ & 0x1);
            if constexpr (!noBias) {
                if ((biasType_) && (l0c_initial) && (ssBiasFull_ == 0)) {
                    SetFlag<HardEvent::M_MTE1>(2);
                }
            }
            // update pingpong flag
            ssl0PingPongFlag_ += useL0PingPong_;
        }
    }
    // Flip the bias ping-pong after the last mad of the accumulation.
    if constexpr (!noBias) {
        if ((biasType_) && (sL0cLast_)) {
            ssBiasPingPongFlag_ += 1 - ssBiasFull_;
        }
    }

}

} // namespace AscendC
#endif